public class GetIntersectionNode {
    /**
     * A node of a singly linked list of ints.
     *
     * Declared as a (non-static) inner class, so instances are created through an
     * enclosing {@code GetIntersectionNode}, e.g. {@code outer.new ListNode(3)}.
     */
    public class ListNode {
        int val;
        ListNode next;

        ListNode(int x) {
            this.val = x;
            this.next = null;
        }
    }

    /**
     * Finds the first node that two singly linked lists share.
     *
     * Nodes are compared by identity ({@code ==}), not by {@code val}. The lists are
     * expected to be acyclic: a cycle makes the length count below loop forever.
     *
     * @param headA head of the first list, may be null
     * @param headB head of the second list, may be null
     * @return the first common node, or null if either head is null or the lists
     *         never meet
     */
    public ListNode getIntersectionNode(ListNode headA, ListNode headB) {
        if (headA == null || headB == null) {
            return null;
        }

        int sizeA = countNodes(headA);
        int sizeB = countNodes(headB);

        // On a tie, A is treated as the "longer" list; the skip below is then zero.
        boolean aIsLonger = sizeA >= sizeB;
        ListNode longer = aIsLonger ? headA : headB;
        ListNode shorter = aIsLonger ? headB : headA;

        // Drop the surplus prefix so both walkers have the same number of nodes left.
        for (int skip = Math.abs(sizeA - sizeB); skip > 0; skip--) {
            longer = longer.next;
        }

        // The walkers now advance in lockstep. They either meet on the shared node
        // or both become null on the same step, which also ends the loop.
        while (longer != shorter) {
            longer = longer.next;
            shorter = shorter.next;
        }
        return longer;
    }

    /** Counts the nodes reachable from {@code node} by following {@code next}. */
    private int countNodes(ListNode node) {
        int count = 0;
        for (ListNode cursor = node; cursor != null; cursor = cursor.next) {
            count++;
        }
        return count;
    }
}
